Tri Dao¹ and Jay Shah²
¹ Together AI / Princeton University, [email protected]
² Colfax Research, jayhshah@colfaxhttp://-intl.com
FlashAttention-2 在 A100 GPU 上经过高度优化,达到了 70% 的利用率。然而,在 H100 GPU 上,FA-2 的利用率仅为 35-40%。下图展示了在 H100 80GB SXM5 上,随着序列长度的增加,PyTorch、FlashAttention 和 FlashAttention-2 在 Attention 前向传播速度上的对比。FlashAttention-2 的性能远未达到 H100 的理论峰值。
FlashAttention-3 旨在解决上述挑战,主要通过以下三个方面进行优化,最终实现了 1.6-3倍 的加速。
H100 上的新指令:
异步性 (Asynchrony):
低精度 (Low-precision):
H100 架构引入的 WGMMA 和 TMA 是性能提升的关键。
mma.sync 指令只能达到峰值吞吞吐量的 2/3。一个关键特性是,WGMMA 和 TMA 都是异步指令:线程在发出这些指令后,可以继续执行其他工作,而指令则在后台执行。
WGMMA 和 TMA 被整合到一个生产者-消费者(producer-consumer)模型的 warp 专用化流水线设计中。
关键设计:
* 使用 warpgroup 范围的寄存器重分配,让消费者 warp 获得更多的寄存器份额。
* 使用共享内存(SMEM)双缓冲技术,将 K 和 V 的加载操作流水线化。
* 使用基于共享内存的屏障(SMEM-resident barriers)和线程本地的流水线状态对象来同步生产者和消费者。
* GEMM0 (Q * K^T) 为 SS 布局(Shared input, Shared output),GEMM1 (P * V) 为 RS 布局(Row input, Shared output)。
为何需要重叠?
* 专用函数单元(Special Function Units, SFU)的吞吐量远低于张量核心(Tensor Cores)。SFU 用于计算 softmax 中的指数函数(exp)。
示例:
* 假设 headdim 为 128,块大小为 128 x 192。
* FP16 WGMMA: 2 x 2 x 128 x 192 x 128 = 12.6 MFLOPS。在 4096 FLOPS/cycle 的速率下,需要 3072 个周期。
* MUFU.EX2 (用于 exp): 128 x 192 = 24.6k OPS。在 16 OPS/cycle 的速率下,需要 1536 个周期。
结论是,MUFU.EX2 的执行时间占到了 WGMMA 的 50%。在 FP8 模式下情况更糟,两者都需要 1536 个周期。因此,我们希望在张量核心忙于 WGMMA 计算的同时,执行 EX2(softmax)操作。
一个简单的解决方案是依赖 warp 调度器自行处理,这在一定程度上有效,但我们可以做得更好。
通过使用同步屏障(bar.sync)实现乒乓调度(Pingpong scheduling),可以更主动地管理和重叠不同 warpgroup 的计算任务,从而提升性能。
在每个 warpgroup 内部,可以进一步利用 WGMMA 的异步特性。
* 将第 k 次迭代的 GEMM1 计算与第 k+1 次迭代的 softmax 计算进行重叠。
* 这种方法会使用更多寄存器,因为下一次 GEMM0 的累加器和当前 GEMM1 的操作数需要同时存在。
通过两阶段的 warpgroup 内重叠,性能可以进一步提升。
* 性能提升: 从 640 TFLOPS 提升到 670 TFLOPS。
我们将乒乓调度和warpgroup内重叠相结合,定义了 FA-3 的计算路径。
bar.sync 和 bar.arrive 以及两个命名屏障来实现乒乓调度。WGMMAX 涉及一组 wgmma.mma_async 指令,这些指令被作为一个组提交。FP8 张量核心(Tensor Cores)可以将 WGMMA 的吞吐量翻倍,但代价是精度的损失。
为了解决 FP8 带来的精度损失问题,特别是异常值(outliers)导致的量化误差,可以采用非相干处理技术。
* 方法: 将 Q 和 K 乘以一个随机正交矩阵,以“分散”异常值。
* 原理: 对于正交矩阵 J (即 J * J^T = I),S = Q * K^T = (QJ)(KJ)^T。
* 效果: 在包含 0.1% 大幅值条目(模拟异常值)的正态分布 QKV 数据上,该方法可将量化误差降低 2.6倍。
下表对比了不同方法下的均方根误差(RMSE):
LDSM/STSM 指令和字节置换(byte permute)。思路: 将物理 CTA(Cooperative Thread Arrays)与逻辑工作块(work tiles)解耦。启动与 SM(Streaming Multiprocessors)数量相等的固定数量的 CTA。这些 CTA 在处理多个工作块期间是持久的。
示例:
Seqlen = 4096, Heads = 8, Batch = 4。 Fix BlockM = 128,so mblocks = 32。32*8*4 = 1024 个工作块需要处理。[1024/132] = 8,每个 CTA 在其生命周期内将运行 7 或 8 个工作块。CUTLASS 核函数是使用三个主要类组合构建的:
由于这种组合式设计,将核函数更改为持久化核函数变得非常简单:只需编写一个不同的 TileScheduler 即可。同时,在加载和 mma 方法中也需要额外的屏障逻辑。
在使用因果掩码(causal masking)时,工作块(work tiles)的主循环(mainloops)迭代次数不同。通过使用最长处理时间优先(Longest-processing-time-first, LPT)算法(Graham, 1969)来进行负载均衡。
这种优化将因果注意力(causal attention)的速度从 670 TFLOPS 提升到 710 TFLOPS。
在解码(decoding)过程中,查询(query)长度很短(通常只有几个 token),而上下文(context)长度很长(例如,128k)。
从 FlashAttention-2 (FA-2) 开始,引入了 Flash Decoding:沿着 KV 序列长度进行拆分,以便为 GPU 提供足够的工作量来充分利用其计算能力。
WGMMA(Warp-Group Matrix Multiply Accumulate)块在 M 维度上的宽度是 64。对于短查询长度来说,这会造成浪费。然而,我们可以通过打包多个查询头来填满 WGMMA 块,这适用于 MQA/GQA(多查询注意力/分组查询注意力)。
DeepSeek 的 MLA 具有很大的头维度(head dim),为 576 / 512。标准的拆分方法没有足够的寄存器。
为此,采用了 Warp 专职化(Warp specialization)策略:
- WG1 (工作组1):同时执行 QK 矩阵乘法和 PV 矩阵乘法。每个线程需要 160 个累加寄存器。
- WG2 (工作组2):仅执行 PV 矩阵乘法。每个线程需要 128 个累加寄存器。
下图展示了 MLA 在批量大小为 128、查询头为 128 的解码速度测试(在 H100 80GB SXM5 上进行)。
seqlen_q = 1(即解码单个 token),也已经达到了计算密集型(compute-bound)的状态。下图展示了在 H100 80GB SXM5 上,头维度为 128 时的前向注意力计算速度。
下图展示了在 H100 80GB SXM5 上,头维度为 256 时的前向注意力计算速度。
下图展示了在 H100 80GB SXM5 上,头维度为 256 时的 FP8 前向注意力计算速度。
下图展示了在 H100 80GB PCIe 上,使用 BF16 精度、头维度为 128、MQA 16、查询序列长度为 4 的解码性能。